{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d81d63f4-a12f-4aa0-956e-a3f8702a1904",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_decision_forests as tfdf\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6c8c1db-8fc9-497e-b5e1-875d8a43635e",
   "metadata": {},
   "source": [
    "##### ❇️ Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "db35588c-df20-4b42-95de-77b7f733a503",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label classes: ['Adelie', 'Gentoo', 'Chinstrap']\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>species</th>\n",
       "      <th>island</th>\n",
       "      <th>bill_length_mm</th>\n",
       "      <th>bill_depth_mm</th>\n",
       "      <th>flipper_length_mm</th>\n",
       "      <th>body_mass_g</th>\n",
       "      <th>sex</th>\n",
       "      <th>year</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>Torgersen</td>\n",
       "      <td>39.1</td>\n",
       "      <td>18.7</td>\n",
       "      <td>181.0</td>\n",
       "      <td>3750.0</td>\n",
       "      <td>male</td>\n",
       "      <td>2007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>Torgersen</td>\n",
       "      <td>39.5</td>\n",
       "      <td>17.4</td>\n",
       "      <td>186.0</td>\n",
       "      <td>3800.0</td>\n",
       "      <td>female</td>\n",
       "      <td>2007</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \\\n",
       "0        0  Torgersen            39.1           18.7              181.0   \n",
       "1        0  Torgersen            39.5           17.4              186.0   \n",
       "\n",
       "   body_mass_g     sex  year  \n",
       "0       3750.0    male  2007  \n",
       "1       3800.0  female  2007  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load and prepare data; Our task would be to predict specie of the penguin\n",
    "dataset_df = pd.read_csv('penguins.csv')\n",
    "label = \"species\"\n",
    "classes = dataset_df[label].unique().tolist()\n",
    "print(f\"Label classes: {classes}\")\n",
    "dataset_df[label] = dataset_df[label].map(classes.index)\n",
    "dataset_df.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6cb0f048-989f-479b-8da4-7a9b4f3c87b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into train/test\n",
    "train_ds_pd, test_ds_pd = train_test_split(dataset_df, test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d280fe66-5228-4564-b1a4-082ee9bbfae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Converting pandas dataframes to tensorflow datasets\n",
    "train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)\n",
    "test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e11b12bb-c619-4136-aacd-5a632fa5493d",
   "metadata": {},
   "source": [
    "##### ❇️ Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2472e857-ef42-4adc-970c-f57cb397fc59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and train model\n",
    "model = tfdf.keras.RandomForestModel()\n",
    "model.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2be95d1-17ba-4cd6-a3e2-58aeef052dc0",
   "metadata": {},
   "source": [
    "##### ❇️ Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "39765bdf-b0af-453a-82f2-c8632ae31455",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 0s 339ms/step - loss: 0.0000e+00 - accuracy: 0.9615\n",
      "loss: 0.0000\n",
      "accuracy: 0.9615\n"
     ]
    }
   ],
   "source": [
    "model.compile(metrics=[\"accuracy\"])\n",
    "evaluation = model.evaluate(test_ds, return_dict=True)\n",
    "for name, value in evaluation.items():\n",
    "    print(f\"{name}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13e9eb33-2782-4198-a3e1-ca0a3d3a6adc",
   "metadata": {},
   "source": [
    "##### ❇️ Save model; ready to be served using tf-serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e22ac43b-4925-4c5f-9618-cd1f1f0f327a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"/path_to_save_model_directory\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b1d3d79-9a8b-48a5-82b2-6ad40056e09f",
   "metadata": {},
   "source": [
    "##### ❇️ Hope you enjoyed reading!! 📖 \n",
    "##### ❇️ follow → @akshay_pachaar  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2511c8e5-391d-4ab4-be8d-3892093e133b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f4e802a-99e1-499c-9063-b25f2c6fdb68",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15e91cef-d4e9-40af-b61b-814b3a7f5c0f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e2a41e7-594d-4d54-9a6f-47ead8adf871",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1e815d2a-ae92-4e91-9856-d0a0afcb3508",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78456c13-b03e-497e-9e3e-1b7b48565cdc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
